{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "基本原理与离散动作的版本相同。区别在于连续动作无法逐个枚举概率：策略网络输出高斯分布的均值 mu 和标准差 sigma，动作的概率（密度）用高斯概率密度函数计算即可。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR8AAAEXCAYAAACUBEAgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAc0ElEQVR4nO3dfXBTdb4/8Hee29ImoZQmdGmlu6DYy4NL0RKduc5v6VK16/qAv6sMox3l6qiBAXGYtbuCo7Mz5Ycz68quojM7K9w/pDs4W11Z0O2vaFnHWKBQLU9d9y5uu5SkPNikLW2SJp/7h/ZcghWTNuk3oe/XzJmx53zSvgPk7ek5OSc6EREQEU0wveoARDQ5sXyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJZeXzyiuvYNasWcjKykJFRQUOHDigKgoRKaCkfP7whz9g/fr1eO6553D48GEsXLgQVVVV6OnpURGHiBTQqbiwtKKiAjfeeCN++9vfAgCi0SiKi4uxZs0aPPPMM9/5+Gg0iu7ubuTl5UGn06U6LhHFSUTQ19eHoqIi6PVX3rcxTlAmTSgUQmtrK2pra7V1er0elZWV8Hg8oz4mGAwiGAxqX58+fRplZWUpz0pEY9PV1YWZM2decWbCy+fcuXOIRCJwOBwx6x0OB06ePDnqY+rq6vD8889/Y31XVxesVmtKchJR4gKBAIqLi5GXl/edsxNePmNRW1uL9evXa1+PPEGr1cryIUpD8RwOmfDyKSgogMFggM/ni1nv8/ngdDpHfYzFYoHFYpmIeEQ0QSb8bJfZbEZ5eTmampq0ddFoFE1NTXC5XBMdh4gUUfJr1/r161FTU4PFixfjpptuwq9//WsMDAzg4YcfVhGHiBRQUj73338/zp49i02bNsHr9eKGG27Ae++9942D0ER09VLyPp/xCgQCsNls8Pv9POBMlEYSeW3y2i4iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEiJhMtn//79uPPOO1FUVASdToe33347ZruIYNOmTZgxYways7NRWVmJzz//PGbmwoULWLlyJaxWK+x2O1atWoX+/v5xPREiyiwJl8/AwAAWLlyIV155ZdTtW7ZswdatW/Haa6+hpaUFU6ZMQVVVFYaGhrSZlStX4tixY2hsbMTu3buxf/9+PPbYY2N/FkSUeWQcAEhDQ4P2dTQaFafTKS+++KK2rre3VywWi+zcuVNERI4fPy4A5ODBg9rM3r17RafTyenTp+P6uX6/XwCI3+8fT3wiSrJEXptJPeZz6tQpeL1eVFZWautsNhsqKirg8XgAAB6PB3a7HYsXL9ZmKisrodfr0dLSMur3DQaDCAQCMQsRZbaklo/X6wUAOByOmPUOh0Pb5vV6UVhYGLPdaDQiPz9fm7lcXV0dbDabthQXFyczNhEpkBFnu2pra+H3+7Wlq6tLdSQiGqeklo/T6QQA+Hy+mPU+n0/b5nQ60dPTE7N9eHgYFy5c0GYuZ7FYYLVaYxYiymxJLZ/S0lI4nU40NTVp6wKBAFpaWuByuQAALpcLvb29aG1t1Wb27duHaDSKioqKZMYhojRmTPQB/f39+Pvf/659ferUKbS1tSE/Px8lJSVYt24dfvnLX2LOnDkoLS3Fxo0bUVRUhLvvvhsAcP311+O2227Do48+itdeew3hcBirV6/GAw88gKKioqQ9MSJKc4meSvvggw8EwDeWmpoaEfnqdPvGjRvF4XCIxWKRpUuXSkdHR8z3OH/+vKxYsUJyc3PFarXKww8/LH19
fXFn4Kl2ovSUyGtTJyKisPvGJBAIwGazwe/38/gPURpJ5LWZEWe7iOjqw/IhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpERC5VNXV4cbb7wReXl5KCwsxN13342Ojo6YmaGhIbjdbkybNg25ublYvnw5fD5fzExnZyeqq6uRk5ODwsJCbNiwAcPDw+N/NkSUMRIqn+bmZrjdbnzyySdobGxEOBzGsmXLMDAwoM089dRTePfdd7Fr1y40Nzeju7sb9957r7Y9EomguroaoVAIH3/8MXbs2IHt27dj06ZNyXtWRJT+ZBx6enoEgDQ3N4uISG9vr5hMJtm1a5c2c+LECQEgHo9HRET27Nkjer1evF6vNrNt2zaxWq0SDAbj+rl+v18AiN/vH098IkqyRF6b4zrm4/f7AQD5+fkAgNbWVoTDYVRWVmozc+fORUlJCTweDwDA4/Fg/vz5cDgc2kxVVRUCgQCOHTs26s8JBoMIBAIxCxFltjGXTzQaxbp163DLLbdg3rx5AACv1wuz2Qy73R4z63A44PV6tZlLi2dk+8i20dTV1cFms2lLcXHxWGMTUZoYc/m43W4cPXoU9fX1ycwzqtraWvj9fm3p6upK+c8kotQyjuVBq1evxu7du7F//37MnDlTW+90OhEKhdDb2xuz9+Pz+eB0OrWZAwcOxHy/kbNhIzOXs1gssFgsY4lKRGkqoT0fEcHq1avR0NCAffv2obS0NGZ7eXk5TCYTmpqatHUdHR3o7OyEy+UCALhcLrS3t6Onp0ebaWxshNVqRVlZ2XieCxFlkIT2fNxuN95880288847yMvL047R2Gw2ZGdnw2azYdWqVVi/fj3y8/NhtVqxZs0auFwuLFmyBACwbNkylJWV4cEHH8SWLVvg9Xrx7LPPwu12c++GaDJJ5DQagFGXN954Q5sZHByUJ598UqZOnSo5OTlyzz33yJkzZ2K+zxdffCG33367ZGdnS0FBgTz99NMSDofjzsFT7UTpKZHXpk5ERF31jU0gEIDNZoPf74fValUdh4i+lshrk9d2EZESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREmO6qp0oGSQaRWRgANFQCDqDAYbsbOjMZuh0OtXRaAKwfGjCiQjCFy7g7N698B84gNC5c9BbLMiZPRuFd9yBvAULoDMYVMekFGP50IQSEQS7u/HFyy9joKMD+PrSwkh/P/znz6O/vR1FDz2E6VVVLKCrHI/50ISKXLyIztdfx8DJk5BoFF8Ggzh07hw+DwQQFUHk4kWc/q//gv/wYWTgNc+UAO750IQREfhbWtD32WcQEXQODGDjkSPo8PsxxWjEf157Le4vLQUuXoSvoQF58+bBkJ2tOjalCPd8aMKEfD6c3rEDiEYhAP5fezuO9/YiIoJAOIzfnjiBo19+CQC4+I9/IDI4qDYwpRTLhyZMX3s7wpd87FEgHI7ZHopGEYxEJjoWKcLyoQkhIgieOQN8XS46AP/H6YTxktPq11qtuCY3V1FCmmg85kMTQsJhDP3rX9rXOp0ONbNnI89kwv8/cwYzsrPx6LXXojArCwBgyMqCTs//N17NWD40IaKhEPpPnoxZZ9Tr8X9nzcJ9s2ZhZP9n5A2GufPmwci9oKsay4cmxHAgABnleI5Op8No72c25uVBZ+Q/z6sZ92tpQvS1tyPS3x/fsMEA66JFqQ1EyrF8KOVEBMN9fdq7mb+LTqdD1owZKU5FqrF8KOUkEkFfW1vc85YZM2DIy0tdIEoLLB9KORkeRvCSj8f+LlnFxTCyfK56LB9KuaDXi8jAQNzzU669
lqfZJwH+DVNKiQgGv/gCkb6++B6g02HKddelNhSlBZYPpdzwJZdUfBfDlCkw2WwpTEPpguVDqSWC/qNH4x7PKi6GubAwhYEoXbB8KKVkePira7riZLRaoTOZUpiI0gXLh1IqdPYswn5/3PP2iooUpqF0wvKhlBm5kn24tzfux5gLCngD+UmC5UMpFfT54p41FRQg63vfS2EaSicsH0op/6FDcc+abDaY8vNTmIbSCcuHUkZCoa+u6YpTVnExwDcXThr8m6aUCfb0INjdHfe89YYbUheG0g7Lh1Jm2O+P+zYaOoMBhrw8HmyeRBIqn23btmHBggWwWq2wWq1wuVzYu3evtn1oaAhutxvTpk1Dbm4uli9fDt9lBxw7OztRXV2NnJwcFBYWYsOGDRgeHk7Os6G0MvD553HPmgoKkHv99SlMQ+kmofKZOXMmNm/ejNbWVhw6dAg/+tGPcNddd+HYsWMAgKeeegrvvvsudu3ahebmZnR3d+Pee+/VHh+JRFBdXY1QKISPP/4YO3bswPbt27Fp06bkPitSTkQwcNltU6/ENHUq9F/fv5kmB52M82Mh8/Pz8eKLL+K+++7D9OnT8eabb+K+++4DAJw8eRLXX389PB4PlixZgr179+InP/kJuru74XA4AACvvfYafvazn+Hs2bMwm82j/oxgMIhgMKh9HQgEUFxcDL/fD6vVOp74lCKRwUH8/fnn0X/8eFzzMx54ADNWrOCvXRkuEAjAZrPF9doc8zGfSCSC+vp6DAwMwOVyobW1FeFwGJWVldrM3LlzUVJSAo/HAwDweDyYP3++VjwAUFVVhUAgoO09jaaurg42m01biouLxxqbJkj4wgVc/O//jnveNG1aCtNQOkq4fNrb25GbmwuLxYLHH38cDQ0NKCsrg9frhdlsht1uj5l3OBzwer0AAK/XG1M8I9tHtn2b2tpa+P1+benq6ko0Nk2w4b4+ROM8lmfIzUXu3Lnc65lkEv54gOuuuw5tbW3w+/146623UFNTg+bm5lRk01gsFlgslpT+DEquwJEj2gcEfhe92QxTQUGKE1G6Sbh8zGYzZs+eDQAoLy/HwYMH8fLLL+P+++9HKBRCb29vzN6Pz+eD0+kEADidThw4cCDm+42cDRuZocwnIgidPRv3fM4PfgA9r2SfdMb9Pp9oNIpgMIjy8nKYTCY0NTVp2zo6OtDZ2QmXywUAcLlcaG9vR88l9/NtbGyE1WpFWVnZeKNQmohevIjBzs6453O+/33eRmMSSmjPp7a2FrfffjtKSkrQ19eHN998Ex9++CHef/992Gw2rFq1CuvXr0d+fj6sVivWrFkDl8uFJUuWAACWLVuGsrIyPPjgg9iyZQu8Xi+effZZuN1u/lp1FYkMDmIo3vLR62GePp3HeyahhMqnp6cHDz30EM6cOQObzYYFCxbg/fffx49//GMAwEsvvQS9Xo/ly5cjGAyiqqoKr776qvZ4g8GA3bt344knnoDL5cKUKVNQU1ODF154IbnPipQKnT0LiUbjmtVbLMidPz/FiSgdjft9Piok8l4CmnjdO3fizM6dcc0a7XaUbd0K02VnSSkzTcj7fIhGI9Fo/B+LDMC6cCGMubkpTETpiuVDSRUdGkIggU8nNRcWAgZD6gJR2mL5UFJFw+G4PypHZzAgb8ECHmyepFg+lFSD//gHokND8Q0bDDDzzYWTFsuHkkZEMNjZieglFwFfSfasWTDyQPOkxfKh5BFB+Msv4x63TJ8OQ05OCgNROmP5UNJIJIL+K9yd4HLZ3/8+j/dMYiwfSpro0BDCFy7EN6zTIe/f/i21gSitsXwoaYb+9S8Mx/keH0N2Ngx5eSlOROmM5UNJISIYOn0a0cHBuOazr7kGlsvu7USTC8uHkmbo9Om4Zw25udB/y21zaXJg+VByRKPo+/TTuMftX9/pgCYvlg8lReTiRUQuXox73lxYmMI0lAlYPpQUg52dcd+90DRtGiy8c+Wkx/KhcRMRDPf2
QsLhuOZN+fm8rIJYPpQcfQm8uTC3rAzQ85/eZMd/AZQUg6dOxT2b84MfpDAJZQqWD43bcCCAYb8/rll9VhYshYW8rIJYPjR+Ia837vf4GO12ZF1zTYoTUSZg+dC4iQgQ563AjVYr9PykEgLLhyaYrbwcOt42lcDyoSQwTJkS10WiOpPpqw8I5PEeAsuHkiDre99D/r//O/AdpZJbVoa8BQsmKBWlO5YPjZ9Ohxn/8R+wLlr0rQWUPWsWZj7yCPRZWRMcjtIVy4fGTafTwWi345o1azC9uhqm/PyvPg5Hp4MhLw/2m29G6dNPI3vWLP7KRRp+YikljYgA0SiCPh+CZ85AolGYCwqQNXMmdEYji2cSSOS1mdBntRNdiU6nAwwGZBUVIauoSHUcSnP8tYuIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkxLjKZ/PmzdDpdFi3bp22bmhoCG63G9OmTUNubi6WL18On88X87jOzk5UV1cjJycHhYWF2LBhA4aHh8cThYgyzJjL5+DBg3j99dex4LKrlJ966im8++672LVrF5qbm9Hd3Y17771X2x6JRFBdXY1QKISPP/4YO3bswPbt27Fp06axPwsiyjwyBn19fTJnzhxpbGyUW2+9VdauXSsiIr29vWIymWTXrl3a7IkTJwSAeDweERHZs2eP6PV68Xq92sy2bdvEarVKMBgc9ecNDQ2J3+/Xlq6uLgEgfr9/LPGJKEX8fn/cr80x7fm43W5UV1ejsrIyZn1rayvC4XDM+rlz56KkpAQejwcA4PF4MH/+fDgcDm2mqqoKgUAAx77l41fq6upgs9m0pbi4eCyxiSiNJFw+9fX1OHz4MOrq6r6xzev1wmw2w263x6x3OBzwer3azKXFM7J9ZNtoamtr4ff7taWrqyvR2ESUZhK6qr2rqwtr165FY2MjsibwplAWiwUW3nSc6KqS0J5Pa2srenp6sGjRIhiNRhiNRjQ3N2Pr1q0wGo1wOBwIhULo7e2NeZzP54Pz68/mdjqd3zj7NfK1k5/fTTRpJFQ+S5cuRXt7O9ra2rRl8eLFWLlypfbfJpMJTU1N2mM6OjrQ2dkJl8sFAHC5XGhvb0dPT48209jYCKvVirKysiQ9LSJKdwn92pWXl4d58+bFrJsyZQqmTZumrV+1ahXWr1+P/Px8WK1WrFmzBi6XC0uWLAEALFu2DGVlZXjwwQexZcsWeL1ePPvss3C73fzVimgSSfqdDF966SXo9XosX74cwWAQVVVVePXVV7XtBoMBu3fvxhNPPAGXy4UpU6agpqYGL7zwQrKjEFEa4z2ciShpEnlt8touIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJVg+RKQEy4eIlGD5EJESLB8iUoLlQ0RKsHyISAmWDxEpwfIhIiVYPkSkBMuHiJRg+RCREiwfIlKC5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQkRIsHyJSguVDREqwfIhICZYPESnB8iEiJYyqA4yFiAAAAoGA4iREdKmR1+TIa/RKMrJ8zp8/DwAoLi5WnISIRtPX1webzXbFmYwsn/z8fABAZ2fndz7BdBMIBFBcXIyuri5YrVbVceLG3BMrU3OLCPr6+lBUVPSdsxlZPnr9V4eqbDZbRv3FXMpqtWZkduaeWJmYO94dAh5wJiIlWD5EpERGlo/FYsFzzz0Hi8WiOkrCMjU7c0+sTM2dCJ3Ec06MiCjJMnLPh4gyH8uHiJRg+RCREiwfIlKC5UNESmRk+bzyyiuYNWsWsrKyUFFRgQMHDijNs3//ftx5550oKiqCTqfD22+/HbNdRLBp0ybMmDED2dnZqKysxOeffx4zc+HCBaxcuRJWqxV2ux2rVq1Cf39/SnPX1dXhxhtvRF5eHgoLC3H33Xejo6MjZmZoaAhutxvTpk1Dbm4uli9fDp/P
FzPT2dmJ6upq5OTkoLCwEBs2bMDw8HDKcm/btg0LFizQ3v3rcrmwd+/etM48ms2bN0On02HdunUZlz0pJMPU19eL2WyW3//+93Ls2DF59NFHxW63i8/nU5Zpz5498otf/EL++Mc/CgBpaGiI2b5582ax2Wzy9ttvy6effio//elPpbS0VAYHB7WZ2267TRYuXCiffPKJ/PWvf5XZs2fLihUrUpq7qqpK3njjDTl69Ki0tbXJHXfcISUlJdLf36/NPP7441JcXCxNTU1y6NAhWbJkidx8883a9uHhYZk3b55UVlbKkSNHZM+ePVJQUCC1tbUpy/2nP/1J/vznP8vf/vY36ejokJ///OdiMpnk6NGjaZv5cgcOHJBZs2bJggULZO3atdr6TMieLBlXPjfddJO43W7t60gkIkVFRVJXV6cw1f+6vHyi0ag4nU558cUXtXW9vb1isVhk586dIiJy/PhxASAHDx7UZvbu3Ss6nU5Onz49Ydl7enoEgDQ3N2s5TSaT7Nq1S5s5ceKEABCPxyMiXxWvXq8Xr9erzWzbtk2sVqsEg8EJyz516lT53e9+lxGZ+/r6ZM6cOdLY2Ci33nqrVj6ZkD2ZMurXrlAohNbWVlRWVmrr9Ho9Kisr4fF4FCb7dqdOnYLX643JbLPZUFFRoWX2eDyw2+1YvHixNlNZWQm9Xo+WlpYJy+r3+wH8710DWltbEQ6HY7LPnTsXJSUlMdnnz58Ph8OhzVRVVSEQCODYsWMpzxyJRFBfX4+BgQG4XK6MyOx2u1FdXR2TEciMP+9kyqir2s+dO4dIJBLzBw8ADocDJ0+eVJTqyrxeLwCMmnlkm9frRWFhYcx2o9GI/Px8bSbVotEo1q1bh1tuuQXz5s3TcpnNZtjt9itmH+25jWxLlfb2drhcLgwNDSE3NxcNDQ0oKytDW1tb2mYGgPr6ehw+fBgHDx78xrZ0/vNOhYwqH0odt9uNo0eP4qOPPlIdJS7XXXcd2tra4Pf78dZbb6GmpgbNzc2qY11RV1cX1q5di8bGRmRlZamOo1xG/dpVUFAAg8HwjaP/Pp8PTqdTUaorG8l1pcxOpxM9PT0x24eHh3HhwoUJeV6rV6/G7t278cEHH2DmzJnaeqfTiVAohN7e3itmH+25jWxLFbPZjNmzZ6O8vBx1dXVYuHAhXn755bTO3Nraip6eHixatAhGoxFGoxHNzc3YunUrjEYjHA5H2mZPhYwqH7PZjPLycjQ1NWnrotEompqa4HK5FCb7dqWlpXA6nTGZA4EAWlpatMwulwu9vb1obW3VZvbt24doNIqKioqUZRMRrF69Gg0NDdi3bx9KS0tjtpeXl8NkMsVk7+joQGdnZ0z29vb2mPJsbGyE1WpFWVlZyrJfLhqNIhgMpnXmpUuXor29HW1tbdqyePFirFy5UvvvdM2eEqqPeCeqvr5eLBaLbN++XY4fPy6PPfaY2O32mKP/E62vr0+OHDkiR44cEQDyq1/9So4cOSL//Oc/ReSrU+12u13eeecd+eyzz+Suu+4a9VT7D3/4Q2lpaZGPPvpI5syZk/JT7U888YTYbDb58MMP5cyZM9py8eJFbebxxx+XkpIS2bdvnxw6dEhcLpe4XC5t+8ip32XLlklbW5u89957Mn369JSe+n3mmWekublZTp06JZ999pk888wzotPp5C9/+UvaZv42l57tyrTs45Vx5SMi8pvf/EZKSkrEbDbLTTfdJJ988onSPB988IEA+MZSU1MjIl+dbt+4caM4HA6xWCyydOlS6ejoiPke58+flxUrVkhubq5YrVZ5+OGHpa+vL6W5R8sMQN544w1tZnBwUJ588kmZOnWq5OTkyD333CNnzpyJ+T5ffPGF3H777ZKdnS0FBQXy9NNPSzgcTlnuRx55RK655hoxm80yffp0Wbp0qVY86Zr521xePpmUfbx4Px8iUiKjjvkQ0dWD5UNESrB8iEgJlg8RKcHyISIlWD5EpATLh4iUYPkQ
kRIsHyJSguVDREqwfIhIif8Bn1TlQxuFi24AAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#Define the environment\n",
    "class MyWrapper(gym.Wrapper):\n",
    "    \"\"\"Pendulum-v1 wrapper with a simplified reset/step API.\n",
    "\n",
    "    step() scales the action by 2, rescales the reward with (r + 8) / 8\n",
    "    and forces the episode to end after max_steps steps.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, max_steps=200):\n",
    "        env = gym.make('Pendulum-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        #Hard cap on episode length, counted by step()\n",
    "        self.max_steps = max_steps\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self, seed=None):\n",
    "        \"\"\"Reset the episode and return only the initial observation.\n",
    "\n",
    "        Pass an int seed to make the episode reproducible.\n",
    "        \"\"\"\n",
    "        state, _ = self.env.reset(seed=seed)\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        \"\"\"Apply a scalar action; return (state, reward, over).\"\"\"\n",
    "        #Scale by 2, presumably to map the tanh-bounded policy mean onto\n",
    "        #Pendulum's torque range\n",
    "        state, reward, terminated, truncated, info = self.env.step(\n",
    "            [action * 2])\n",
    "        over = terminated or truncated\n",
    "\n",
    "        #Shift the reward to make training easier\n",
    "        reward = (reward + 8) / 8\n",
    "\n",
    "        #Limit the number of steps per episode\n",
    "        self.step_n += 1\n",
    "        if self.step_n >= self.max_steps:\n",
    "            over = True\n",
    "\n",
    "        return state, reward, over\n",
    "\n",
    "    #Draw the current frame\n",
    "    def show(self):\n",
    "        from matplotlib import pyplot as plt\n",
    "        plt.figure(figsize=(3, 3))\n",
    "        plt.imshow(self.env.render())\n",
    "        plt.show()\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()\n",
    "\n",
    "env.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((tensor([[ 0.1318],\n",
       "          [-0.0188]], grad_fn=<TanhBackward0>),\n",
       "  tensor([[1.0605],\n",
       "          [0.8742]], grad_fn=<ExpBackward0>)),\n",
       " tensor([[0.1309],\n",
       "         [0.0797]], grad_fn=<AddmmBackward0>))"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "#定义模型\n",
    "class Model(torch.nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.s = torch.nn.Sequential(\n",
    "            torch.nn.Linear(3, 64),\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.Linear(64, 64),\n",
    "            torch.nn.ReLU(),\n",
    "        )\n",
    "        self.mu = torch.nn.Sequential(\n",
    "            torch.nn.Linear(64, 1),\n",
    "            torch.nn.Tanh(),\n",
    "        )\n",
    "        self.sigma = torch.nn.Sequential(\n",
    "            torch.nn.Linear(64, 1),\n",
    "            torch.nn.Tanh(),\n",
    "        )\n",
    "\n",
    "    def forward(self, state):\n",
    "        state = self.s(state)\n",
    "\n",
    "        return self.mu(state), self.sigma(state).exp()\n",
    "\n",
    "\n",
    "model_action = Model()\n",
    "\n",
    "model_value = torch.nn.Sequential(\n",
    "    torch.nn.Linear(3, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 1),\n",
    ")\n",
    "\n",
    "model_action(torch.randn(2, 3)), model_value(torch.randn(2, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\appDir\\python3.10\\lib\\site-packages\\gym\\utils\\passive_env_checker.py:233: DeprecationWarning: `np.bool8` is a deprecated alias for `np.bool_`.  (Deprecated NumPy 1.24)\n",
      "  if not isinstance(terminated, (bool, np.bool8)):\n",
      "C:\\Users\\Administrator\\AppData\\Local\\Temp\\ipykernel_11748\\3994207745.py:34: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ..\\torch\\csrc\\utils\\tensor_new.cpp:248.)\n",
      "  state = torch.FloatTensor(state).reshape(-1, 3)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "-2.640882968902588"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "#Play one episode and record its transitions\n",
    "def play(show=False):\n",
    "    \"\"\"Run one episode with the current Gaussian policy.\n",
    "\n",
    "    Returns (state, action, reward, next_state, over, reward_sum); the first\n",
    "    five are tensors with one row per step, reward_sum is a Python float.\n",
    "    \"\"\"\n",
    "    state = []\n",
    "    action = []\n",
    "    reward = []\n",
    "    next_state = []\n",
    "    over = []\n",
    "\n",
    "    s = env.reset()\n",
    "    o = False\n",
    "    while not o:\n",
    "        #Sample an action from N(mu, sigma); the rollout needs no autograd graph\n",
    "        with torch.no_grad():\n",
    "            mu, sigma = model_action(torch.FloatTensor(s).reshape(1, 3))\n",
    "        a = random.normalvariate(mu=mu.item(), sigma=sigma.item())\n",
    "\n",
    "        ns, r, o = env.step(a)\n",
    "\n",
    "        state.append(s)\n",
    "        action.append(a)\n",
    "        reward.append(r)\n",
    "        next_state.append(ns)\n",
    "        over.append(o)\n",
    "\n",
    "        s = ns\n",
    "\n",
    "        if show:\n",
    "            display.clear_output(wait=True)\n",
    "            env.show()\n",
    "\n",
    "    #Stack the observations into one ndarray first: building a tensor from a\n",
    "    #list of ndarrays is very slow (torch warns about it)\n",
    "    state = torch.tensor(np.array(state), dtype=torch.float32).reshape(-1, 3)\n",
    "    action = torch.FloatTensor(action).reshape(-1, 1)\n",
    "    reward = torch.FloatTensor(reward).reshape(-1, 1)\n",
    "    next_state = torch.tensor(np.array(next_state), dtype=torch.float32).reshape(-1, 3)\n",
    "    over = torch.LongTensor(over).reshape(-1, 1)\n",
    "\n",
    "    return state, action, reward, next_state, over, reward.sum().item()\n",
    "\n",
    "\n",
    "state, action, reward, next_state, over, reward_sum = play()\n",
    "\n",
    "reward_sum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Separate optimizers for actor and critic; the critic's step size is 10x larger\n",
    "optimizer_action = torch.optim.Adam(model_action.parameters(), lr=5e-4)\n",
    "optimizer_value = torch.optim.Adam(model_value.parameters(), lr=5e-3)\n",
    "\n",
    "\n",
    "def requires_grad(model, value):\n",
    "    \"\"\"Enable (value=True) or freeze (value=False) every parameter of model.\"\"\"\n",
    "    model.requires_grad_(value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([200, 1])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def train_value(state, reward, next_state, over, gamma=0.98, n_epochs=10):\n",
    "    \"\"\"Fit model_value to one-step TD targets and return the TD errors.\n",
    "\n",
    "    target = reward + gamma * V(next_state) * (1 - over) is held fixed while\n",
    "    the critic is trained n_epochs times on this batch. Returns the detached\n",
    "    TD error target - V(state), i.e. the target with V subtracted as a baseline.\n",
    "    \"\"\"\n",
    "    if n_epochs < 1:\n",
    "        raise ValueError('n_epochs must be >= 1')\n",
    "\n",
    "    requires_grad(model_action, False)\n",
    "    requires_grad(model_value, True)\n",
    "\n",
    "    #Compute the target once, without gradients\n",
    "    #NOTE(review): over is also True when the episode is cut off by the step\n",
    "    #limit, so the last step never bootstraps from V(next_state); confirm intended.\n",
    "    with torch.no_grad():\n",
    "        target = model_value(next_state)\n",
    "    target = target * gamma * (1 - over) + reward\n",
    "\n",
    "    #Reuse the same batch n_epochs times\n",
    "    for _ in range(n_epochs):\n",
    "        value = model_value(state)\n",
    "\n",
    "        loss = torch.nn.functional.mse_loss(value, target)\n",
    "        loss.backward()\n",
    "        optimizer_value.step()\n",
    "        optimizer_value.zero_grad()\n",
    "\n",
    "    #value comes from the last epoch, i.e. before the final optimizer step\n",
    "    return (target - value).detach()\n",
    "\n",
    "\n",
    "value = train_value(state, reward, next_state, over)\n",
    "\n",
    "value.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-0.2696605920791626"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def train_action(state, action, value, decay=0.9 * 0.9, clip_eps=0.2, n_epochs=10):\n",
    "    \"\"\"One PPO-clip update of model_action on a single episode.\n",
    "\n",
    "    value holds per-step TD errors (as returned by train_value). They are\n",
    "    turned into advantages A_i = sum_{j>=i} decay**(j-i) * value[j] and used in\n",
    "    the clipped surrogate objective. Returns the loss of the last epoch.\n",
    "    \"\"\"\n",
    "    if n_epochs < 1:\n",
    "        raise ValueError('n_epochs must be >= 1')\n",
    "\n",
    "    requires_grad(model_action, True)\n",
    "    requires_grad(model_value, False)\n",
    "\n",
    "    #Advantage as an exponentially decayed sum of future TD errors (GAE-style,\n",
    "    #not a Monte Carlo Q estimate). The backward recursion\n",
    "    #A_i = value[i] + decay * A_{i+1} gives the same sums in O(n) instead of O(n^2).\n",
    "    td_errors = value.reshape(-1).tolist()\n",
    "    advantages = [0.0] * len(td_errors)\n",
    "    running = 0.0\n",
    "    for i in reversed(range(len(td_errors))):\n",
    "        running = td_errors[i] + decay * running\n",
    "        advantages[i] = running\n",
    "    delta = torch.FloatTensor(advantages).reshape(-1, 1)\n",
    "\n",
    "    #Log-probability of the taken actions under the policy before the update\n",
    "    with torch.no_grad():\n",
    "        mu, sigma = model_action(state)\n",
    "        log_prob_old = torch.distributions.Normal(mu, sigma).log_prob(action)\n",
    "\n",
    "    #Reuse the same batch n_epochs times\n",
    "    for _ in range(n_epochs):\n",
    "        mu, sigma = model_action(state)\n",
    "        log_prob_new = torch.distributions.Normal(mu, sigma).log_prob(action)\n",
    "\n",
    "        #Probability ratio new/old, computed in log space to avoid dividing\n",
    "        #by a density that may underflow to 0\n",
    "        ratio = (log_prob_new - log_prob_old).exp()\n",
    "\n",
    "        #Clipped and unclipped surrogates; take the element-wise minimum\n",
    "        surr1 = ratio * delta\n",
    "        surr2 = ratio.clamp(1 - clip_eps, 1 + clip_eps) * delta\n",
    "\n",
    "        loss = -torch.min(surr1, surr2).mean()\n",
    "\n",
    "        #Update the policy parameters\n",
    "        loss.backward()\n",
    "        optimizer_action.step()\n",
    "        optimizer_action.zero_grad()\n",
    "\n",
    "    return loss.item()\n",
    "\n",
    "\n",
    "train_action(state, action, value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.35754355788230896 46.55806384086609\n"
     ]
    }
   ],
   "source": [
    "def train(n_epochs=100, log_every=10, min_steps=200, n_test=20):\n",
    "    \"\"\"Train actor and critic for n_epochs epochs, printing progress.\n",
    "\n",
    "    Each epoch plays episodes until at least min_steps steps were collected,\n",
    "    updating both models after every episode. Every log_every epochs (and\n",
    "    after the last one) it prints epoch, last actor loss and the mean reward\n",
    "    over n_test evaluation episodes.\n",
    "    \"\"\"\n",
    "    if min_steps < 1 or n_test < 1 or log_every < 1:\n",
    "        raise ValueError('min_steps, n_test and log_every must be >= 1')\n",
    "\n",
    "    model_action.train()\n",
    "    model_value.train()\n",
    "\n",
    "    for epoch in range(n_epochs):\n",
    "        #Collect at least min_steps steps per epoch\n",
    "        steps = 0\n",
    "        while steps < min_steps:\n",
    "            state, action, reward, next_state, over, _ = play()\n",
    "            steps += len(state)\n",
    "\n",
    "            #Train both models on this episode\n",
    "            delta = train_value(state, reward, next_state, over)\n",
    "            loss = train_action(state, action, delta)\n",
    "\n",
    "        #Log periodically and after the final epoch so progress stays visible\n",
    "        if epoch % log_every == 0 or epoch == n_epochs - 1:\n",
    "            test_result = sum(play()[-1] for _ in range(n_test)) / n_test\n",
    "            print(epoch, loss, test_result)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAR8AAAEXCAYAAACUBEAgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAlmklEQVR4nO3dfVxUdb4H8M8M88DjDAIyAxuIpaasD1toMNpaeyVxozZX11Wva1SmVwVXY7ebtGZr211Md3uwUmvbzXq1ZYt7dc2HklCxDEVJVoWkh1UhYQZTmQGEAWa+94+WuU2JCs5wQD/v12teL+f8fuec77HOx3PO78w5KhEREBF1M7XSBRDRtYnhQ0SKYPgQkSIYPkSkCIYPESmC4UNEimD4EJEiGD5EpAiGDxEpguFDRIpQLHxefPFFJCQkIDAwEMnJySguLlaqFCJSgCLh8/bbbyM7OxuPP/44Pv74Y4wYMQJpaWmora1VohwiUoBKiR+WJicnY9SoUXjhhRcAAG63G3FxcViwYAEWL158yfndbjeqq6sRFhYGlUrl73KJ6DKJCOrr6xEbGwu1+uLHNppuqsmjpaUFJSUlyMnJ8UxTq9VITU1FUVHRBedxOp1wOp2e76dOnUJiYqLfayWirqmqqsJ111130T7dHj5fffUVXC4XTCaT13STyYRjx45dcJ7c3FwsW7bsO9OrqqpgMBj8UicRdZ7D4UBcXBzCwsIu2bfbw6crcnJykJ2d7fnevoEGg4HhQ9QDXc7lkG4Pn6ioKAQEBMBms3lNt9lsMJvNF5xHr9dDr9d3R3lE1E26fbRLp9MhKSkJBQUFnmlutxsFBQWwWCzdXQ4RKUSR067s7GxkZGRg5MiRuOWWW/Dss8+isbER999/vxLlEJECFAmfqVOn4vTp01i6dCmsVit+8IMf4N133/3ORWgiunopcp/PlXI4HDAajbDb7bzgTNSDdGbf5G+7iEgRDB8iUgTDh4gUwfAhIkUwfIhIEQwfIlIEw4eIFMHwISJFMHyISBEMHyJSBMOHiBTB8CEiRTB8iEgRDB8iUgTDh4gUwfAhIkUwfIhIEQwfIlIEw4eIFMHwISJFMHyISBEMHyJSBMOHiBTB8CEiRTB8iEgRDB8iUgTDh4gUwfAhIkUwfIhIEQwfIlIEw4eIFMHwISJFMHyISBEMHyJSBMOHiBTR6fDZs2cP7r77bsTGxkKlUmHTpk1e7SKCpUuXIiYmBkFBQUhNTcVnn33m1efs2bOYMWMGDAYDwsPDMWvWLDQ0NFzRhhBR79Lp8GlsbMSIESPw4osvXrB9xYoVWLVqFdauXYv9+/cjJCQEaWlpaG5u9vSZMWMGysrKkJ+fjy1btmDPnj2YM2dO17eCiHofuQIAZOPGjZ7vbrdbzGazrFy50jOtrq5O9Hq9vPXWWyIiUl5eLgDkwIEDnj7bt28XlUolp06duqz12u12ASB2u/1KyiciH+vMvunTaz7Hjx+H1WpFamqqZ5rRaERycjKKiooAAEVFRQgPD8fIkSM9fVJTU6FWq7F///4LLtfpdMLhcHh9iKh382n4WK1WAIDJZPKabjKZPG1WqxXR0dFe7RqNBhEREZ4+35abmwuj0ej5xMXF+bJsIlJArxjtysnJgd1u93yqqqqULomIrpBPw8dsNgMAbDab13SbzeZpM5vNqK2t9Wpva2vD2bNnPX2+Ta/Xw2AweH2IqHfzafj0798fZrMZBQUFnmkOhwP79++HxWIBAFgsFtTV1aGkpMTTZ+fOnXC73UhOTvZlOUTUg2k6O0NDQwM+//xzz/fjx4+jtLQUERERiI+Px6JFi/Dkk09i4MCB6N+/Px577DHExsZi4sSJAIAhQ4ZgwoQJmD17NtauXYvW1lZkZWVh2rRpiI2N9dmGEVEP19mhtF27dgmA73wyMjJE5Ovh9scee0xMJpPo9XoZN26cVFRUeC3jzJkzMn36dAkNDRWDwSD333+/
1NfXX3YNHGon6pk6s2+qREQUzL4ucTgcMBqNsNvtvP5D1IN0Zt/sFaNdRHT1YfgQkSIYPkSkCIYPESmC4UNEimD4EJEiGD5EpAiGDxEpguFDRIpg+BCRIhg+RKQIhg8RKYLhQ0SKYPgQkSIYPkSkCIYPESmC4UNEimD4EJEiGD5EpAiGDxEpguFDRIro9Hu7iK52IgJndTXqiorQ+MUXAICQAQMQbrFAHxMDlUqlcIVXB4YP0b+JCKS1FV8VFMC2YQNaTp/2tNXt3YvT27fje/feiz633gqVmicNV4rhQ4Svg8fV2IjqN97A6ffeA1yu7/Rpqa1F1csvQx0UBOPIkTwCukKMbyIAzupqnHj2WZx+990LBk+7NocD1W+8AVdjYzdWd3Vi+NA1TUTQVFWFL3JzYS8uBtzuS87TdOIEGo8d64bqrm487aJrlojg/Gef4cRzz6G5qqozM0IucnREl4fhQ9ckcbvhOHQIlWvWoKW2VulyrkkMH7rmiNuN2nfeQc3bb8PV0NDp+XV9+yL4+uv9UNm1heFD1xR3aytq33kH1X/9K6S1tdPzq7RamCZOhDYqyg/VXVsYPnTNcJ0/j5q330btO+9A2to6vwCVClGpqYgaP57D7D7A8KFrgtNmw5d//jPqDhy46FD6xYRbLIidORNqvd7H1V2bGD50VRMRtNTW4l8rV+L8p592aRkqnQ59J0xAzLRp0ISG+rjCaxfDh65aIoLmykqceO45nP/88y4tQx0YiNiZM9F3wgSotVofV3htY/hQtxERrw8AqFQqqNVqn19DERE0lJXh5IsvwnnqVJeWERAcjPj58/lbLj/p1N9obm4uRo0ahbCwMERHR2PixImoqKjw6tPc3IzMzExERkYiNDQUkydPhs1m8+pTWVmJ9PR0BAcHIzo6Gg8//DDaunIBkHo0EYHb7UZjYyOOHj2Kv/3tb1i2bBkyMjIwYcIEWCwWzJs3D01NTb5dr9uNr/Lz8UVubteDJzQU/RYsQJ8f/pDB4yedOvIpLCxEZmYmRo0ahba2Njz66KMYP348ysvLERISAgB46KGHsHXrVuTl5cFoNCIrKwuTJk3C3r17AQAulwvp6ekwm8346KOPUFNTg3vvvRdarRa///3vfb+F1O1EBC0tLfj000+xZcsWvPvuu6ioqEBDQwMCAgIQGBiIPn36ICwsDIGBgT496nG3tuJMfj6+fPVVuJ3OLi0j+IYbEDdnDkIGD+aolh+ppP34twtOnz6N6OhoFBYWYuzYsbDb7ejbty/efPNN/OxnPwMAHDt2DEOGDEFRURFSUlKwfft23HXXXaiurobJZAIArF27Fo888ghOnz4NnU53yfU6HA4YjUbY7XYYDIaulk8+JiJobm7Gvn378Morr+D9999HU1MT4uPjMXr0aCQnJyMxMRHx8fEIDg5GQEAANBoNgoKCfLKTu5qaYNu0Cda//x3S0tKlZYQNG4b4zEw+t6eLOrNvXtE1H7vdDgCIiIgAAJSUlKC1tRWpqamePoMHD0Z8fLwnfIqKijBs2DBP8ABAWloa5s2bh7KyMtx0003fWY/T6YTzG/+KORyOKymb/MDtdqO8vBx/+MMfsHnzZmg0Gtxxxx2YPn06kpOTER4eDo1G47cduvXcOVT96U+o27eva/fwAAj7wQ/QPzsbGqORwdMNuhw+brcbixYtwpgxYzB06FAAgNVqhU6nQ3h4uFdfk8kEq9Xq6fPN4Glvb2+7kNzcXCxbtqyrpZIfiQiamprw+uuvY/ny5airq8Pdd9+NBQsW4KabbvJr4LSvv62uDseffhr1//xnl5ah0mgQdccdiPnP/4TWaPRxhdSRLodPZmYmjh49ig8//NCX9VxQTk4OsrOzPd8dDgfi4uL8vl66OBFBbW0tli5dijfeeAM33ngjnn32WUyYMAF6vd7vRw8iAmdNDU6uWoWG8vIuLUOl0yFm6lSYfvIT3jzYzboUPllZWdiyZQv27NmD6667zjPd
bDajpaUFdXV1Xkc/NpsNZrPZ06e4uNhree2jYe19vk2v10PP/zF6FBHBiRMnsGDBAuzatQs///nP8dvf/hbx8fHdcsoiIjj/6ac4uXo1mo4f79Iy1IGBuO6BBxCVmgqVhneddLdOjSGKCLKysrBx40bs3LkT/fv392pPSkqCVqtFQUGBZ1pFRQUqKythsVgAABaLBUeOHEHtNx5jkJ+fD4PBgMTExCvZFuom7cEzZ84c7N27Fzk5OXj++efRr1+/7gketxvnPvgAnz/5ZJeDJyAkBHGzZ3/9Oy0GjyI69beemZmJN998E//4xz8QFhbmuUZjNBoRFBQEo9GIWbNmITs7GxERETAYDFiwYAEsFgtSUlIAAOPHj0diYiJmzpyJFStWwGq1YsmSJcjMzOTRTS8gIrBarcjKykJJSQmeeOIJzJkz57JGKX2y/rY2nP3gA1S99BJc5893aRmB8fGI/6//Quj3v897eBTUqaH2jv5Ve/XVV3HfffcB+Pomw1/96ld466234HQ6kZaWhtWrV3udUp08eRLz5s3D7t27ERISgoyMDCxfvhyay/wXiEPtymlsbMQvf/lL5OXlYcmSJVi4cCF0Ol23HPG4W1pQu3Urat56C+7m5i4tI2TQIPRbuBCB113HES0/6My+eUX3+SiF4aMMt9uN559/Hr/5zW+QkZGBlStX+uwenUtpa2xE1Usv4dzevV16Dg8AhAwejOsfeQTaiAgGj590230+dO0QERw6dAgrVqzATTfdhCVLlnRL8IgIXA0NOLFqFez793dtIQEBiLz9dsTOnAltnz4Mnh6C4UOXpbGxEb///e/R0tKC3/3udzCbzd0SPC21tahcswaOjz/u0jJUGg1MEyfCPGUKAoKCfFwhXQmGD12SiGDHjh3YsWMH5s6di1tvvbVbgqfpX/9C5Usvdfk1NSqtFrG/+AWi77qLj8PogRg+dEkOh8MzaDBv3rzLHhjoKhGBo6QEJ59/Hq3nznVpGergYMROn47ou+6CKiDAxxWSLzB86KJEBHv37kVxcTGys7ORkJDg3/W5XKgrLkbliy+irYu/4dPHxCB+7lyEjRjBofQejOFDF9Xa2or169fDYDBg6tSpfj3dcre14Ux+Pk699lqX7+EJSkhAwqJFCOrfnxeWeziGD11UdXU1du/ejbFjx2LAgAF+26FdTU348s9/xpndu7v8OIyghARcv3gxH4fRSzB8qEMigo8++gh1dXW46667/HKtR0TgOn8eVS+/jLO7dnVtIWo1+oweje9lZED/jScmiAgcDgcMBgPDqAfiCTF1yOVy4YMPPoDRaERKSopfnrPcevYsTj7/PM7u3t21hQQEoO+ddyJ+/nyv4AG+ft7UqlWr0NzFu6HJvxg+1CGn04mDBw9i0KBBiImJ8emyPW+WeOYZ1BUVAV240V4VEICYKVNw3X33feeVNiKCffv2Yd26dTh58qSvyiYfYvhQh2w2G6qrqzFixAif/nBURNBQXo7Pli1D/eHDXQoedXAwzFOnIubnP4f6ArW53W5s2LABX375JQoKCtALf0V01WP4UIe+/PJLNDU1YeDAgVD7aMhaRNB0/DiOr1iB1q++6tIytFFR6P/QQ4iZMqXDx2FUVVWhoKAAbW1t2Lp1q8/fkEFXjuFDHaqtrYXL5fJ6YNyVcp0/jy/XrUPruXMQEZxzOnHwq6/wmcMB92UcnehjYnD9ww/DeMstHd48KCJ4//33UVVVBeDrZ4v/61//4tFPD8PRLurQ+fPnISKIjIz02cVm+4EDqD98GCKCysZGPHboECrsdoRoNHhw0CBM7d8fAR2sSx8bi+sfeQTB33qI3bc1NTUhLy/PEzbnzp3D+++/j8TERI569SA88qEOiQjUajVCffh+8obycsDthgB46sgRlNfVwSUCR2srXvjkExy90M8pVCoYb7kFA5YuvWTwAEBZWRkOHDjg+e5yubBlyxaOevUwDB+6qPbXGfuD41vP5Wlxu+F0
ubw7qdWITE1Fwi9/icDY2Esu0+12Y9OmTdDpdBg7diw0Gg2GDh2K8vJyHDt2jKdePQjDhy7K7Xb79Iih/bEWKgA/Mpuh+cZp0CCDAf2+eZSlViP67rsRN3s2AsLCLmv5Z86cQVVVFV5//XVMmjQJOp0OK1euxH333YePPvrIZ9tBV47XfKhDGo0GIuJ5OaQvRP7oR/gqPx+uhgZkDBiAMK0W79fUICYoCLMHDUJ0YCAAQB0UhL4TJiB2xowLDqVfzO9+9zvExcXh4MGD0Gg0iI+Px9KlS1HexdfrkH8wfKhD4eHhUKvVsNlsEBGfXKwNjI9HzPTpOPXaa9C0tGBKQgJ+lpCA9iWrVCpojEbEPfgg+tx6a6cfhxEVFeX585dffonQ0FDPO+Ev9DZcUg7DhzoUExMDrVaLkydP+ix8VGo1+k6YAACw/f3vaD13Dqp/X4dRabUIGTQIMVOnImz48C49DqO9RpfLhS+++AImkwnBwcFebdQzMHyoQ3FxcTAajSgrK4PL5fLZhWe1Vovo9HQYb74ZjkOH4LRaoQ4KQuiQIQgdMsQnjzs9c+YMKisrMWzYME/4UM/C8KEOGQwGDB48GIcPH0ZDQwP69Onjs2Wr1GoEfu97CPze93y2zG86ceIEampq8Itf/KLb3ilGncPRLuqQVqvFmDFjcOrUqV41TN3+KJCWlhaMGTOGp1s9FMOHOqRSqTB27Fi43W7s3r2714SP0+nEe++9h+uvvx5DhgxRuhzqAMOHOqRSqTB06FAMGTIE77zzDhoaGnp8AIkIKioqcPDgQfzHf/yH1+gX9SwMH7qosLAw3HPPPSgrK8MHH3ygdDmX5Ha7kZeXh5aWFkyZMgUBfHNFj8XwoYtSq9WYOHEijEYj/vSnP/Xo30eJCE6ePIm3334bFosFSUlJSpdEF8HwoUu6/vrrMWXKFOzatQs7d+7ssadeLpcL69atg81mw5w5cxDEN5T2aAwfuiSNRoPZs2cjMjISTz31FM6cOaN0Sd8hIigtLcVf/vIXjB07FuPHj+coVw/H8KFLUqlUGDRoELKysnDw4EGsXbsWbW1tSpflxeFw4Mknn0Rrayv++7//GyEhIUqXRJfA8KHLolar8cADD2Ds2LFYtWpVj3kusoigra0Na9aswfvvv4/Zs2dj9OjRPOrpBRg+dNmMRiNyc3MRFhaGX//61ygrK1M8gEQEmzdvxh//+EfceuutWLhwod/fJU++wfChy6ZSqTBixAg89dRTsFqtyMzMxBdffKFYALXf/JidnQ2z2Yw//OEPvK+nF2H4UKeo1Wrcc889+O1vf4vDhw9jzpw5qKio6PYAcrlc2LFjB+bMmQONRoMXXniBz2juZRg+1GlarRYPPvggHn/8cZSWlmLGjBnYu3cvXN9+BKqfNDc347XXXsODDz4IlUqFl19+GT/84Q/99rhX8g/+16Iu0el0mD9/Pp5++mnU1NRg2rRpeOmll1BfX++3oyARwalTp7B48WI89NBDMJlMeP311/GjH/2IwdMbSSesXr1ahg0bJmFhYRIWFiYpKSmybds2T3tTU5PMnz9fIiIiJCQkRCZNmiRWq9VrGSdPnpQ777xTgoKCpG/fvvLrX/9aWltbO1OG2O12ASB2u71T85HvtbW1yc6dOyUpKUmCgoJk0qRJsm/fPmltbRW32+2TdbjdbmlsbJQNGzbIyJEjJSgoSKZMmSKff/65z9ZBvtGZfbNT4bN582bZunWrfPrpp1JRUSGPPvqoaLVaOXr0qIiIzJ07V+Li4qSgoEAOHjwoKSkpMnr0aM/8bW1tMnToUElNTZVDhw7Jtm3bJCoqSnJycvy2geR/brdbKisrZcGCBRIRESEmk0nmz58vxcXF0tzcLG63u9Mh4Xa7xeVySV1dnWzevFnS09MlNDRUEhIS5LnnnhO73c7g6YH8Fj4X0qdPH3nllVekrq5OtFqt5OXledo++eQTASBFRUUiIrJt2zZRq9VeR0Nr1qwR
g8EgTqezw3U0NzeL3W73fKqqqhg+PYzb7Ran0ym7du2Sn/70p2I0GiUqKkruuusuWbt2rRw9elQcDoe0tLRcMIzaw8bpdMrZs2elqKhI/ud//kdSUlIkLCxMTCaTZGZmSnl5ubhcLoW2ki6lM+HT5RsiXC4X8vLy0NjYCIvFgpKSErS2tiI1NdXTZ/DgwYiPj0dRURFSUlJQVFSEYcOGwWQyefqkpaVh3rx5KCsr6/AB37m5uVi2bFlXS6VuoFKpoNPpcNttt2HUqFEoLi7GX//6V+Tn52PXrl0IDg5GQkICvv/97+OGG26A2WyGwWBAQEAAnE4nzp07h+rqanz22WcoKytDTU0N2tra0K9fP2RlZWHatGkYPHgwn0p4Fel0+Bw5cgQWiwXNzc0IDQ3Fxo0bkZiYiNLSUuh0OoSHh3v1N5lMsFqtAACr1eoVPO3t7W0dycnJQXZ2tue7w+FAXFxcZ0unbqBSqRASEoLbb78dY8aMQXV1NT788EPs2bMHJSUl2LFjB86fP4+2tja43W7Pg+nVajW0Wi3CwsLQr18/pKen4/bbb8fIkSMRGRkJtVrNYfSrTKfD58Ybb0RpaSnsdjs2bNiAjIwMFBYW+qM2D71eD71e79d1kG+1Hwn169cP/fr1w7Rp09Dc3AybzYaamhqcO3cO58+fh8vlgkajQWhoKKKiohAbG4vIyEjodDoGzlWu0+Gj0+kwYMAAAEBSUhIOHDiA5557DlOnTkVLSwvq6uq8jn5sNhvMZjMAwGw2o7i42Gt5NpvN00ZXn/bwaA+Y0NBQ3HDDDQpXRT3BFd8c4Xa74XQ6kZSUBK1Wi4KCAk9bRUUFKisrYbFYAAAWiwVHjhxBbW2tp09+fj4MBgMSExOvtBQi6kU6deSTk5ODH//4x4iPj0d9fT3efPNN7N69G++99x6MRiNmzZqF7OxsREREwGAwYMGCBbBYLEhJSQEAjB8/HomJiZg5cyZWrFgBq9WKJUuWIDMzk6dVRNeYToVPbW0t7r33XtTU1MBoNGL48OF47733cMcddwAAnnnmGajVakyePBlOpxNpaWlYvXq1Z/6AgABs2bIF8+bNg8ViQUhICDIyMvDEE0/4dquIqMdTifSAh7J0ksPhgNFohN1uh8FgULocIvq3zuyb/EEMESmC4UNEimD4EJEiGD5EpAiGDxEpguFDRIpg+BCRIhg+RKQIhg8RKYLhQ0SKYPgQkSIYPkSkCIYPESmC4UNEimD4EJEiGD5EpAiGDxEpguFDRIpg+BCRIhg+RKQIhg8RKYLhQ0SKYPgQkSIYPkSkCIYPESmC4UNEimD4EJEiGD5EpAiGDxEpguFDRIpg+BCRIhg+RKQIhg8RKYLhQ0SKYPgQkSKuKHyWL18OlUqFRYsWeaY1NzcjMzMTkZGRCA0NxeTJk2Gz2bzmq6ysRHp6OoKDgxEdHY2HH34YbW1tV1IKEfUyXQ6fAwcO4KWXXsLw4cO9pj/00EN45513kJeXh8LCQlRXV2PSpEmedpfLhfT0dLS0tOCjjz7Ca6+9hnXr1mHp0qVd3woi6n2kC+rr62XgwIGSn58vt912myxcuFBEROrq6kSr1UpeXp6n7yeffCIApKioSEREtm3bJmq1WqxWq6fPmjVrxGAwiNPpvOD6mpubxW63ez5VVVUCQOx2e1fKJyI/sdvtl71vdunIJzMzE+np6UhNTfWaXlJSgtbWVq/pgwcPRnx8PIqKigAARUVFGDZsGEwmk6dPWloaHA4HysrKLri+3NxcGI1GzycuLq4rZRNRD9Lp8Fm/fj0+/vhj5ObmfqfNarVCp9MhPDzca7rJZILVavX0+WbwtLe3t11ITk4O7Ha751NVVdXZsomoh9F0pnNVVRUWLlyI/Px8BAYG+qum79Dr9dDr9d22PiLyv04d+ZSUlKC2thY333wzNBoNNBoNCgsLsWrVKmg0GphMJrS0tKCurs5rPpvNBrPZDAAwm83fGf1q/97eh4iufp0K
n3HjxuHIkSMoLS31fEaOHIkZM2Z4/qzValFQUOCZp6KiApWVlbBYLAAAi8WCI0eOoLa21tMnPz8fBoMBiYmJPtosIurpOnXaFRYWhqFDh3pNCwkJQWRkpGf6rFmzkJ2djYiICBgMBixYsAAWiwUpKSkAgPHjxyMxMREzZ87EihUrYLVasWTJEmRmZvLUiuga0qnwuRzPPPMM1Go1Jk+eDKfTibS0NKxevdrTHhAQgC1btmDevHmwWCwICQlBRkYGnnjiCV+XQkQ9mEpEROkiOsvhcMBoNMJut8NgMChdDhH9W2f2Tf62i4gUwfAhIkUwfIhIEQwfIlIEw4eIFMHwISJFMHyISBEMHyJSBMOHiBTB8CEiRTB8iEgRDB8iUgTDh4gUwfAhIkUwfIhIEQwfIlIEw4eIFMHwISJFMHyISBEMHyJSBMOHiBTB8CEiRTB8iEgRDB8iUgTDh4gUwfAhIkUwfIhIEQwfIlIEw4eIFMHwISJFMHyISBEMHyJSBMOHiBTB8CEiRTB8iEgRDB8iUgTDh4gUoVG6gK4QEQCAw+FQuBIi+qb2fbJ9H72YXhk+Z86cAQDExcUpXAkRXUh9fT2MRuNF+/TK8ImIiAAAVFZWXnIDexqHw4G4uDhUVVXBYDAoXc5lY93dq7fWLSKor69HbGzsJfv2yvBRq7++VGU0GnvVf5hvMhgMvbJ21t29emPdl3tAwAvORKQIhg8RKaJXho9er8fjjz8OvV6vdCmd1ltrZ93dq7fW3RkquZwxMSIiH+uVRz5E1PsxfIhIEQwfIlIEw4eIFMHwISJF9MrwefHFF5GQkIDAwEAkJyejuLhY0Xr27NmDu+++G7GxsVCpVNi0aZNXu4hg6dKliImJQVBQEFJTU/HZZ5959Tl79ixmzJgBg8GA8PBwzJo1Cw0NDX6tOzc3F6NGjUJYWBiio6MxceJEVFRUePVpbm5GZmYmIiMjERoaismTJ8Nms3n1qaysRHp6OoKDgxEdHY2HH34YbW1tfqt7zZo1GD58uOfuX4vFgu3bt/fomi9k+fLlUKlUWLRoUa+r3Sekl1m/fr3odDr5y1/+ImVlZTJ79mwJDw8Xm82mWE3btm2T3/zmN/K///u/AkA2btzo1b58+XIxGo2yadMm+ec//yk/+clPpH///tLU1OTpM2HCBBkxYoTs27dPPvjgAxkwYIBMnz7dr3WnpaXJq6++KkePHpXS0lK58847JT4+XhoaGjx95s6dK3FxcVJQUCAHDx6UlJQUGT16tKe9ra1Nhg4dKqmpqXLo0CHZtm2bREVFSU5Ojt/q3rx5s2zdulU+/fRTqaiokEcffVS0Wq0cPXq0x9b8bcXFxZKQkCDDhw+XhQsXeqb3htp9pdeFzy233CKZmZme7y6XS2JjYyU3N1fBqv7ft8PH7XaL2WyWlStXeqbV1dWJXq+Xt956S0REysvLBYAcOHDA02f79u2iUqnk1KlT3VZ7bW2tAJDCwkJPnVqtVvLy8jx9PvnkEwEgRUVFIvJ18KrVarFarZ4+a9asEYPBIE6ns9tq79Onj7zyyiu9oub6+noZOHCg5Ofny2233eYJn95Quy/1qtOulpYWlJSUIDU11TNNrVYjNTUVRUVFClbWsePHj8NqtXrVbDQakZyc7Km5qKgI4eHhGDlypKdPamoq1Go19u/f32212u12AP//1ICSkhK0trZ61T548GDEx8d71T5s2DCYTCZPn7S0NDgcDpSVlfm9ZpfLhfXr16OxsREWi6VX1JyZmYn09HSvGoHe8fftS73qV+1fffUVXC6X1188AJhMJhw7dkyhqi7OarUCwAVrbm+zWq2Ijo72atdoNIiIiPD08Te3241FixZhzJgxGDp0qKcunU6H8PDwi9Z+oW1rb/OXI0eOwGKxoLm5GaGhodi4cSMSExNRWlraY2sGgPXr1+Pjjz/GgQMHvtPWk/++/aFXhQ/5T2ZmJo4ePYoPP/xQ6VIuy4033ojS0lLY7XZs
2LABGRkZKCwsVLqsi6qqqsLChQuRn5+PwMBApctRXK867YqKikJAQMB3rv7bbDaYzWaFqrq49rouVrPZbEZtba1Xe1tbG86ePdst25WVlYUtW7Zg165duO666zzTzWYzWlpaUFdXd9HaL7Rt7W3+otPpMGDAACQlJSE3NxcjRozAc88916NrLikpQW1tLW6++WZoNBpoNBoUFhZi1apV0Gg0MJlMPbZ2f+hV4aPT6ZCUlISCggLPNLfbjYKCAlgsFgUr61j//v1hNpu9anY4HNi/f7+nZovFgrq6OpSUlHj67Ny5E263G8nJyX6rTUSQlZWFjRs3YufOnejfv79Xe1JSErRarVftFRUVqKys9Kr9yJEjXuGZn58Pg8GAxMREv9X+bW63G06ns0fXPG7cOBw5cgSlpaWez8iRIzFjxgzPn3tq7X6h9BXvzlq/fr3o9XpZt26dlJeXy5w5cyQ8PNzr6n93q6+vl0OHDsmhQ4cEgDz99NNy6NAhOXnypIh8PdQeHh4u//jHP+Tw4cNyzz33XHCo/aabbpL9+/fLhx9+KAMHDvT7UPu8efPEaDTK7t27paamxvM5f/68p8/cuXMlPj5edu7cKQcPHhSLxSIWi8XT3j70O378eCktLZV3331X+vbt69eh38WLF0thYaEcP35cDh8+LIsXLxaVSiU7duzosTV35JujXb2t9ivV68JHROT555+X+Ph40el0csstt8i+ffsUrWfXrl0C4DufjIwMEfl6uP2xxx4Tk8kker1exo0bJxUVFV7LOHPmjEyfPl1CQ0PFYDDI/fffL/X19X6t+0I1A5BXX33V06epqUnmz58vffr0keDgYPnpT38qNTU1Xss5ceKE/PjHP5agoCCJioqSX/3qV9La2uq3uh944AHp16+f6HQ66du3r4wbN84TPD215o58O3x6U+1Xis/zISJF9KprPkR09WD4EJEiGD5EpAiGDxEpguFDRIpg+BCRIhg+RKQIhg8RKYLhQ0SKYPgQkSIYPkSkiP8D+Z5Lt0KiNYMAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "62.642608642578125"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Watch one episode with the trained policy; the value shown is its total reward\n",
    "play(show=True)[-1]"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第9章-策略梯度算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
